{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recurrent Neural Nets - Sentiment Analysis\n",
    "\n",
    "One branch of deep learning is dedicated to processing sequences such as text and time series. These networks are called **Recurrent Neural Networks (RNNs)**. Long Short-Term Memory networks (LSTMs) are one popular type of RNN; Gated Recurrent Units (GRUs) are another.\n",
    "\n",
    "The illustration below is from http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (a highly recommended read).\n",
    "\n",
    "![RNNs](../images/RNN-unrolled.png)\n",
    "\n",
    "In lesson 5 we looked at getting the sentiment of a given movie review. The data comes from an IMDB review set where a rating of less than 5 was classified as negative and greater than 5 as positive. Neutral reviews were ignored.\n",
    "\n",
    "In the previous lesson we considered a Bag of Words (BoW) model, where the emphasis is on how many times a particular word appears in the review. This worked fairly well, giving around 80% accuracy.\n",
    "\n",
    "One thing that was missing was the structure of the sentence: word order was not taken into account. For example, a sentence such as \"I wanted to hate it so much but I loved the movie.\" would probably confuse the previous model, simply because both 'love' and 'hate' appear in it. The previous lesson also applied other preprocessing steps such as stemming (e.g. hated -> hate, runner -> run), which will not be done in this lesson. Here we maintain the structure and feed a numeric representation of the words, in order, into our deep learning model.\n",
    "\n",
    "<img src=\"../images/happy_trump.png\" alt=\"happy\" style=\"width: 150px;\"/><img src=\"../images/sad_trump.png\" alt=\"sad\" style=\"width: 150px;\"/>\n",
    "\n",
    "## References\n",
    "1. Similar work done for a fake news classifier: https://blog.kjamistan.com/comparing-scikit-learn-text-classifiers-on-a-fake-news-dataset/\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/ywinX5wgdEU?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import HTML\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/ywinX5wgdEU?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import re\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization, LSTM, Embedding, Reshape\n",
    "from keras.models import load_model, model_from_json\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import os\n",
    "import urllib.request  # the download cell below calls urllib.request.urlretrieve\n",
    "\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!mkdir -p data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if not os.path.isfile('data/reviews2.pkl'):\n",
    "    urllib.request.urlretrieve('https://www.dropbox.com/s/15tfttuzqe7fimg/reviews2.pkl?dl=1','data/reviews2.pkl') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocessing steps: convert to lower case, remove URLs, strip all punctuation except periods, and normalise whitespace."
   ]
  },
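  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what these steps do, here is a minimal sketch that applies the same regular expressions to a single string with Python's `re` module instead of pandas. The helper name `clean_review` and the sample sentence are made up for illustration and are not part of the dataset.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def clean_review(text):\n",
    "    # Hypothetical helper: the same cleaning steps as the pandas cell, on one string\n",
    "    text = text.lower()\n",
    "    text = re.sub(r'http[\\w:/\\.]+', '', text)  # remove urls\n",
    "    text = re.sub(r'[^\\.\\w\\s]', '', text)      # keep only word characters, whitespace and periods\n",
    "    text = re.sub(r'\\.\\.+', '.', text)         # collapse runs of periods into one\n",
    "    text = re.sub(r'\\.', ' .', text)           # put a space before each period\n",
    "    text = re.sub(r'\\s\\s+', ' ', text)         # collapse runs of whitespace\n",
    "    return text.strip()\n",
    "\n",
    "sample = 'I LOVED it... Really!! See http://example.com/review for more.'\n",
    "print(clean_review(sample))\n",
    "# prints: i loved it . really see for more .\n",
    "```\n",
    "\n",
    "Putting a space before each period means that `.` becomes its own token once the text is split on whitespace."
   ]
  },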
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(25000, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Reviews</th>\n",
       "      <th>Sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bromwell high is a cartoon comedy . it ran at ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>homelessness or houselessness as george carlin...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>brilliant overacting by lesley ann warren . be...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>this is easily the most underrated film inn th...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>this is not the typical mel brooks film . it w...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Reviews  Sentiment\n",
       "0  bromwell high is a cartoon comedy . it ran at ...          1\n",
       "1  homelessness or houselessness as george carlin...          1\n",
       "2  brilliant overacting by lesley ann warren . be...          1\n",
       "3  this is easily the most underrated film inn th...          1\n",
       "4  this is not the typical mel brooks film . it w...          1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_pickle('data/reviews2.pkl')\n",
    "df.Reviews = df.Reviews.str.lower()\n",
    "df.Reviews = df.Reviews.str.replace(r'http[\\w:/\\.]+', '', regex=True) # remove urls\n",
    "df.Reviews = df.Reviews.str.replace(r'[^\\.\\w\\s]', '', regex=True) # remove everything except word characters, whitespace and periods\n",
    "df.Reviews = df.Reviews.str.replace(r'\\.\\.+', '.', regex=True) # replace multiple periods with a single one\n",
    "df.Reviews = df.Reviews.str.replace(r'\\.', ' .', regex=True) # put a space before each period so it becomes its own token\n",
    "df.Reviews = df.Reviews.str.replace(r'\\s\\s+', ' ', regex=True) # replace multiple whitespace characters with a single space\n",
    "df.Reviews = df.Reviews.str.strip() \n",
    "print(df.shape)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'bromwell high is a cartoon comedy . it ran at the same time as some other programs about school life such as teachers . my 35 years in the teaching profession lead me to believe that bromwell highs satire is much closer to reality than is teachers . the scramble to survive financially the insightful students who can see right through their pathetic teachers pomp the pettiness of the whole situation all remind me of the schools i knew and their students . when i saw the episode in which a student repeatedly tried to burn down the school i immediately recalled . at . high . a classic line inspector im here to sack one of your teachers . student welcome to bromwell high . i expect that many adults of my age think that bromwell high is far fetched . what a pity that it isnt'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.Reviews.values[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all the unique words, keeping only those that have been used more than 5 times; rarer words will share a single `<Other>` id. From this we create a dictionary mapping words to integers.\n",
    "\n",
    "Once this is done we will create a list of reviews where the words are converted to ints."
   ]
  },
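  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mapping described above can be sketched on a tiny corpus. The three toy reviews, the `toy_` variable names and the lowered threshold `min_count = 1` are assumptions chosen so that the example stays readable; the notebook itself uses a threshold of 5 on the full dataset.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "toy_reviews = ['good movie good plot', 'bad movie', 'good fun']\n",
    "min_count = 1  # the notebook uses 5; lowered so the toy corpus keeps some words\n",
    "\n",
    "# Count every word and keep those used more than min_count times,\n",
    "# ordered from most to least frequent\n",
    "counts = Counter(' '.join(toy_reviews).split()).most_common()\n",
    "vocab = [word for word, n in counts if n > min_count]\n",
    "\n",
    "# Frequent words get ids 0..len(vocab)-1; all rare words share one extra id\n",
    "toy_word2num = {word: i for i, word in enumerate(vocab)}\n",
    "toy_word2num['<Other>'] = len(vocab)\n",
    "\n",
    "other_id = toy_word2num['<Other>']\n",
    "encoded = [[toy_word2num.get(w, other_id) for w in review.split()]\n",
    "           for review in toy_reviews]\n",
    "\n",
    "print(toy_word2num)  # {'good': 0, 'movie': 1, '<Other>': 2}\n",
    "print(encoded)       # [[0, 1, 0, 2], [2, 1], [0, 2]]\n",
    "```\n",
    "\n",
    "Because the ids follow frequency order, small integers correspond to common words, which is why the first review in the real data contains many single-digit ids."
   ]
  },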
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The number of unique words are:  27915\n",
      "The first review looks like this: \n",
      "[22054, 323, 6, 3, 1074, 217, 1, 8, 2102, 32, 0, 167, 59, 14, 47, 81, 5531, 43, 400, 118]\n"
     ]
    }
   ],
   "source": [
    "all_text = ' '.join(df.Reviews.values)\n",
    "words = all_text.split()\n",
    "u_words = Counter(words).most_common()\n",
    "u_words = [word[0] for word in u_words if word[1]>5] # we will only consider words that have been used more than 5 times\n",
    "# create the dictionary\n",
    "word2num = dict(zip(u_words,range(len(u_words))))\n",
    "word2num['<Other>'] = len(u_words)\n",
    "num2word = dict(zip(word2num.values(), word2num.keys()))\n",
    "\n",
    "int_text = [[word2num[word] if word in word2num else len(u_words) for word in Review.split()] for Review in df.Reviews.values]\n",
    "\n",
    "print('The number of unique words are: ', len(u_words))\n",
    "print('The first review looks like this: ')\n",
    "print(int_text[0][:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAD8CAYAAACRkhiPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE0hJREFUeJzt3X+oX/d93/Hnq3KimTRe7fpOCEmeFBAtsiFOdNE0GsJW\n01qJR+X9YxToLIawBvbaFDaGtMLW/SFwByur2WzQkszSlkXT2gaLJm5RtJQymONep05kyVGtxDbS\nRb+arqjtQJ3V9/74fjx/d32v7/dKV/erez/PBxy+n+/7nM/R+eggve75nPP93lQVkqQ+/ci4D0CS\nND6GgCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljd4z7AOZz77331saNG8d9GJK0\nrLzyyit/XFUT821324fAxo0bmZqaGvdhSNKykuTtUbZzOkiSOmYISFLHDAFJ6pghIEkdMwQkqWOG\ngCR1zBCQpI4ZApLUMUNAkjo27yeGk/wE8F+HSh8D/gVwuNU3Am8Bj1XV/2p99gN7gOvAL1bV77b6\nVuB54E7g68Dn6zb6Tfcb931t1vpbTz+yxEciSUtj3iuBqjpTVQ9W1YPAVuB/A18F9gEnqmozcKK9\nJ8kWYBdwP7ADeDbJqra754AngM1t2bG4w5EkLcRCp4MeAr5fVW8DO4FDrX4IeLS1dwJHqupaVb0J\nnAW2JVkL3FVVL7Wf/g8P9ZEkjcFCQ2AX8JXWXlNVF1r7IrCmtdcB54b6nG+1da09sy5JGpORQyDJ\nh4GfA/7bzHXtJ/tFm9tPsjfJVJKpK1euLNZuJUkzLORK4DPAt6vqUnt/qU3x0F4vt/o0sGGo3/pW\nm27tmfX3qaqDVTVZVZMTE/N+HbYk6QYtJAQ+x3tTQQDHgN2tvRt4Yai+K8nqJJsY3AB+uU0dXU2y\nPUmAx4f6SJLGYKRfKpPkI8DPAP9oqPw0cDTJHuBt4DGAqjqV5ChwGngHeKqqrrc+T/LeI6IvtkWS\nNCYjhUBV/QXw4zNqP2TwtNBs2x8ADsxSnwIeWPhhSpJuBT8xLEkdMwQkqWOGgCR1zBCQpI4ZApLU\nMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmCEgSR0z\nBCSpY4aAJHXMEJCkjo0UAkl+LMlvJPlekteT/O0k9yQ5nuSN9nr30Pb7k5xNcibJw0P1rUlOtnXP\nJMmtGJQkaTSjXgn8OvA7VfWTwMeB14F9wImq2gycaO9JsgXYBdwP7ACeTbKq7ec54Algc1t2LNI4\nJEk3YN4QSPLXgU8DXwSoqr+sqj8FdgKH2maHgEdbeydwpKquVdWbwFlgW5K1wF1V9VJVFXB4qI8k\naQxGuRLYBFwB/mOSP0zyhSQfAdZU1YW2zUVgTWuvA84N9T/fautae2ZdkjQmo4TAHcAngeeq6hPA\nX9Cmft7VfrKvxTqoJHuTTCWZunLlymLtVpI0wyghcB44X1Xfau9/g0EoXGpTPLTXy239NLBhqP/6\nVptu7Zn196mqg1U1WVWTExMTo45FkrRA84ZAVV0EziX5iVZ6CDgNHAN2t9pu4IXWPgbsSrI6ySYG\nN4BfblNHV5Nsb08FPT7UR5I0BneMuN0vAF9O8mHgB8A/ZBAgR5PsAd4GHgOoqlNJjjIIineAp6rq\netvPk8DzwJ3Ai22RJI3JSCFQVa8Ck7OsemiO7Q8AB2apTwEPLOQAJUm3jp8YlqSOGQKS1DFDQJI6\nZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pghIEkdMwQkqWOG\ngCR1zBCQpI4ZApLUMUNAkjpmCEhSx0YKgSRvJTmZ5NUkU612T5LjSd5or3cPbb8/ydkkZ5I8PFTf\n2vZzNskzSbL4Q5IkjWohVwJ/t6oerK
rJ9n4fcKKqNgMn2nuSbAF2AfcDO4Bnk6xqfZ4DngA2t2XH\nzQ9BknSjbmY6aCdwqLUPAY8O1Y9U1bWqehM4C2xLsha4q6peqqoCDg/1kSSNwaghUMA3krySZG+r\nramqC619EVjT2uuAc0N9z7fautaeWZckjckdI273qaqaTvI3gONJvje8sqoqSS3WQbWg2Qtw3333\nLdZuJUkzjHQlUFXT7fUy8FVgG3CpTfHQXi+3zaeBDUPd17fadGvPrM/25x2sqsmqmpyYmBh9NJKk\nBZk3BJJ8JMlH320DPwu8BhwDdrfNdgMvtPYxYFeS1Uk2MbgB/HKbOrqaZHt7KujxoT6SpDEYZTpo\nDfDV9jTnHcB/qarfSfIHwNEke4C3gccAqupUkqPAaeAd4Kmqut729STwPHAn8GJbJEljMm8IVNUP\ngI/PUv8h8NAcfQ4AB2apTwEPLPwwJUm3gp8YlqSOGQKS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSp\nY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6tiov1msaxv3fW3W+ltPP7LERyJJi8srAUnqmCEg\nSR0zBCSpY4aAJHXMEJCkjhkCktSxkUMgyaokf5jkt9v7e5IcT/JGe717aNv9Sc4mOZPk4aH61iQn\n27pnkmRxhyNJWoiFXAl8Hnh96P0+4ERVbQZOtPck2QLsAu4HdgDPJlnV+jwHPAFsbsuOmzp6SdJN\nGSkEkqwHHgG+MFTeCRxq7UPAo0P1I1V1rareBM4C25KsBe6qqpeqqoDDQ30kSWMw6pXAvwX+GfBX\nQ7U1VXWhtS8Ca1p7HXBuaLvzrbautWfWJUljMm8IJPl7wOWqemWubdpP9rVYB5Vkb5KpJFNXrlxZ\nrN1KkmYY5Urgp4CfS/IWcAT46ST/GbjUpnhor5fb9tPAhqH+61tturVn1t+nqg5W1WRVTU5MTCxg\nOJKkhZg3BKpqf1Wtr6qNDG74/veq+nngGLC7bbYbeKG1jwG7kqxOsonBDeCX29TR1STb21NBjw/1\nkSSNwc18i+jTwNEke4C3gccAqupUkqPAaeAd4Kmqut76PAk8D9wJvNgWSdKYLCgEqur3gN9r7R8C\nD82x3QHgwCz1KeCBhR6kJOnW8BPDktQxQ0CSOmYISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4Z\nApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnq2Lwh\nkOSvJXk5yXeSnEryr1r9niTHk7zRXu8e6rM/ydkkZ5I8PFTfmuRkW/dMktyaYUmSRjHKlcA14Ker\n6uPAg8COJNuBfcCJqtoMnGjvSbIF2AXcD+wAnk2yqu3rOeAJYHNbdiziWCRJCzRvCNTAn7e3H2pL\nATuBQ61+CHi0tXcCR6rqWlW9CZwFtiVZC9xVVS9VVQGHh/pIksZgpHsCSVYleRW4DByvqm8Ba6rq\nQtvkIrCmtdcB54a6n2+1da09sy5JGpORQqCqrlfVg8B6Bj/VPzBjfTG4OlgUSfYmmUoydeXKlcXa\nrSRphgU9HVRVfwp8k8Fc/qU2xUN7vdw2mwY2DHVb32rTrT2zPtufc7CqJqtqcmJiYiGHKElagFGe\nDppI8mOtfSfwM8D3gGPA7rbZbuCF1j4G7EqyOskmBjeAX25TR1eTbG9PBT0+1EeSNAZ3jLDNWuBQ\ne8LnR4CjVfXbSf4ncDTJHuBt4DGAqjqV5ChwGngHeKqqrrd9PQk8D9wJvNgWSdKYzBsCVfVd4BOz\n1H8IPDRHnwPAgVnqU8AD7+8hSRoHPzEsSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOjbK5wQ0h437\nvjZr/a2nH1niI5GkG9NlCMz1n7ck9cbpIEnqmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOmYI\nSF
LHDAFJ6pghIEkdMwQkqWOGgCR1bN4QSLIhyTeTnE5yKsnnW/2eJMeTvNFe7x7qsz/J2SRnkjw8\nVN+a5GRb90yS3JphSZJGMcqVwDvAP6mqLcB24KkkW4B9wImq2gycaO9p63YB9wM7gGeTrGr7eg54\nAtjclh2LOBZJ0gLNGwJVdaGqvt3afwa8DqwDdgKH2maHgEdbeydwpKquVdWbwFlgW5K1wF1V9VJV\nFXB4qI8kaQwWdE8gyUbgE8C3gDVVdaGtugisae11wLmhbudbbV1rz6xLksZk5BBI8qPAbwK/VFVX\nh9e1n+xrsQ4qyd4kU0mmrly5sli7lSTNMFIIJPkQgwD4clX9VitfalM8tNfLrT4NbBjqvr7Vplt7\nZv19qupgVU1W1eTExMSoY5EkLdAoTwcF+CLwelX92tCqY8Du1t4NvDBU35VkdZJNDG4Av9ymjq4m\n2d72+fhQH0nSGIzyO4Z/CvgHwMkkr7baPweeBo4m2QO8DTwGUFWnkhwFTjN4suipqrre+j0JPA/c\nCbzYFknSmMwbAlX1P4C5nud/aI4+B4ADs9SngAcWcoCSpFvHTwxLUscMAUnq2Cj3BLRAG/d9bdb6\nW08/ssRHIkkfzCsBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmCEgSR0z\nBCSpY4aAJHXML5BbQnN9sRz45XKSxsMrAUnqmCEgSR0zBCSpY4aAJHVs3hBI8qUkl5O8NlS7J8nx\nJG+017uH1u1PcjbJmSQPD9W3JjnZ1j2TZK5fXi9JWiKjXAk8D+yYUdsHnKiqzcCJ9p4kW4BdwP2t\nz7NJVrU+zwFPAJvbMnOfkqQlNm8IVNXvA38yo7wTONTah4BHh+pHqupaVb0JnAW2JVkL3FVVL1VV\nAYeH+kiSxuRG7wmsqaoLrX0RWNPa64BzQ9udb7V1rT2zLkkao5v+sFhVVZJajIN5V5K9wF6A++67\n74b380EfzpIk3XgIXEqytqoutKmey60+DWwY2m59q0239sz6rKrqIHAQYHJyclED5nY1V2D5SWJJ\nt9KNTgcdA3a39m7ghaH6riSrk2xicAP45TZ1dDXJ9vZU0ONDfSRJYzLvlUCSrwB/B7g3yXngXwJP\nA0eT7AHeBh4DqKpTSY4Cp4F3gKeq6nrb1ZMMnjS6E3ixLZKkMZo3BKrqc3OsemiO7Q8AB2apTwEP\nLOjoJEm3lJ8YlqSOGQKS1DFDQJI65i+Vuc356KikW8krAUnqmCEgSR0zBCSpY94TWKa8VyBpMXgl\nIEkdMwQkqWOGgCR1zBCQpI55Y3iF8YaxpIXwSkCSOuaVQCe8QpA0G68EJKljhoAkdczpoM45TST1\nzRDQrAwHqQ+GgBbEcJBWliUPgSQ7gF8HVgFfqKqnl/oYtPjmCoe53EhoGEDS4lvSG8NJVgH/HvgM\nsAX4XJItS3kMkqT3LPWVwDbgbFX9ACDJEWAncHqJj0NjttArB0m3xlKHwDrg3ND788DfWuJj0Arj\nNJF0427LG8NJ9gJ729s/T3LmBnZzL/DHi3dUtzXHOov86i0+klvP87oyLdVY/+YoGy11CEwDG4be\nr2+1/09VHQQO3swflGSqqiZvZh/LhWNdmRzrynS7jXWpPzH8B8DmJJuSfBjYBRxb4mOQJDVLeiVQ\nVe8k+cfA7zJ4RPRLVXVqKY9BkvSeJb8nUFVfB76+BH/UTU0nLTOOdWVyrCvTbTXWVNW4j0GSNCZ+\ni6gkdWzFhUCSHUnOJDmbZN+4j2cxJHkryckkryaZarV7khxP8kZ7vXto+/1t/GeSPDy+I59fki8l\nuZzktaHagseWZGv7Ozqb5JkkWeqxzGeOsf5Kkul2bl9N8tmhdct5rBuSfDPJ6SSnkny+1Vfcuf2A\nsS6Pc1tVK2ZhcLP5+8DHgA8D3wG2jPu4FmFcbwH3zqj9a2Bfa+8D
frW1t7RxrwY2tb+PVeMewweM\n7dPAJ4HXbmZswMvAdiDAi8Bnxj22Ecf6K8A/nWXb5T7WtcAnW/ujwB+1Ma24c/sBY10W53alXQn8\nv6+lqKq/BN79WoqVaCdwqLUPAY8O1Y9U1bWqehM4y+Dv5bZUVb8P/MmM8oLGlmQtcFdVvVSDf0mH\nh/rcNuYY61yW+1gvVNW3W/vPgNcZfGPAiju3HzDWudxWY11pITDb11J80MlYLgr4RpJX2qepAdZU\n1YXWvgisae2V8Hew0LGta+2Z9eXiF5J8t00XvTs9smLGmmQj8AngW6zwcztjrLAMzu1KC4GV6lNV\n9SCDb199Ksmnh1e2nxpW5GNeK3lszXMMpi8fBC4A/2a8h7O4kvwo8JvAL1XV1eF1K+3czjLWZXFu\nV1oIjPS1FMtNVU2318vAVxlM71xql4+018tt85Xwd7DQsU239sz6ba+qLlXV9ar6K+A/8N7U3bIf\na5IPMfhP8ctV9VutvCLP7WxjXS7ndqWFwIr7WookH0ny0XfbwM8CrzEY1+622W7ghdY+BuxKsjrJ\nJmAzg5tNy8mCxtamF64m2d6epnh8qM9t7d3/EJu/z+DcwjIfazu2LwKvV9WvDa1aced2rrEum3M7\n7jvri70An2Vwd/77wC+P+3gWYTwfY/AkwXeAU++OCfhx4ATwBvAN4J6hPr/cxn+G2+xJilnG9xUG\nl8r/h8Ec6J4bGRswyeAf2feBf0f7IOTttMwx1v8EnAS+y+A/h7UrZKyfYjDV813g1bZ8diWe2w8Y\n67I4t35iWJI6ttKmgyRJC2AISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4ZApLUsf8LzF9DIxRe\nnRAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f57fc3083c8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist([len(t) for t in int_text],50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The number of reviews greater than 500 in length is:  2234\n",
      "The number of reviews less than 50 in length is:  411\n"
     ]
    }
   ],
   "source": [
    "print('The number of reviews greater than 500 in length is: ', np.sum(np.array([len(t)>500 for t in int_text])))\n",
    "print('The number of reviews less than 50 in length is: ', np.sum(np.array([len(t)<50 for t in int_text])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network expects every input sequence in a batch to have the same length, so reviews of differing lengths cannot be passed in as they are. Reviews shorter than 500 words are pre-padded with a `<PAD>` token, and reviews longer than 500 words are truncated to their first 500 words. This assumes that the sentiment of a review can be inferred from its first 500 words."
   ]
  },
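  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of pre-padding and truncation, the sketch below uses a target length of 5 instead of 500. The function name `pad_or_truncate` and the pad id `99` are placeholders for this example only.\n",
    "\n",
    "```python\n",
    "def pad_or_truncate(seq, length, pad_id):\n",
    "    # Left-pad short sequences with pad_id; keep the first `length` items of long ones\n",
    "    if len(seq) < length:\n",
    "        return [pad_id] * (length - len(seq)) + seq\n",
    "    return seq[:length]\n",
    "\n",
    "print(pad_or_truncate([4, 8, 15], 5, 99))             # [99, 99, 4, 8, 15]\n",
    "print(pad_or_truncate([1, 2, 3, 4, 5, 6, 7], 5, 99))  # [1, 2, 3, 4, 5]\n",
    "```\n",
    "\n",
    "Padding on the left keeps the real words at the end of the sequence, so they are the last inputs the LSTM sees before it produces its output."
   ]
  },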
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num2word[len(word2num)] = '<PAD>'\n",
    "word2num['<PAD>'] = len(word2num)\n",
    "\n",
    "for i, t in enumerate(int_text):\n",
    "    if len(t)<500:\n",
    "        int_text[i] = [word2num['<PAD>']]*(500-len(t)) + t\n",
    "    elif len(t)>500:\n",
    "        int_text[i] = t[:500]\n",
    "    else:\n",
    "        continue\n",
    "\n",
    "x = np.array(int_text)\n",
    "y = df.Sentiment.values\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Many to One LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, None, 32)          893344    \n",
      "_________________________________________________________________\n",
      "lstm_1 (LSTM)                (None, 64)                24832     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 65        \n",
      "=================================================================\n",
      "Total params: 918,241.0\n",
      "Trainable params: 918,241\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(len(word2num), 32)) # , batch_size=batch_size\n",
    "model.add(LSTM(64))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "\n",
    "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
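  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be reproduced by hand. The sketch below assumes the standard LSTM layout of four weight blocks (input, forget and output gates plus the candidate cell state), each with input weights, recurrent weights and a bias; the variable names are for illustration only.\n",
    "\n",
    "```python\n",
    "vocab_size = 27917  # 27915 frequent words + '<Other>' + '<PAD>'\n",
    "embed_dim = 32\n",
    "lstm_units = 64\n",
    "\n",
    "# One 32-dimensional vector per vocabulary entry\n",
    "embedding_params = vocab_size * embed_dim\n",
    "\n",
    "# Four blocks, each with (input + recurrent) weights and a bias vector\n",
    "lstm_params = 4 * ((embed_dim + lstm_units) * lstm_units + lstm_units)\n",
    "\n",
    "# A single sigmoid unit: one weight per LSTM output plus a bias\n",
    "dense_params = lstm_units * 1 + 1\n",
    "\n",
    "print(embedding_params, lstm_params, dense_params)     # 893344 24832 65\n",
    "print(embedding_params + lstm_params + dense_params)  # 918241\n",
    "```\n",
    "\n",
    "Almost all of the parameters sit in the embedding table, so the vocabulary size is the main driver of model size here."
   ]
  },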
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "22500/22500 [==============================] - 145s - loss: 0.5698 - acc: 0.7091   \n",
      "Epoch 2/2\n",
      "22500/22500 [==============================] - 143s - loss: 0.2964 - acc: 0.8843   \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f577cdf1ac8>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 128\n",
    "model.fit(X_train, y_train, batch_size=batch_size, epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "`model.evaluate` returns a list: the first number is the loss and the remaining values are the metrics passed to `compile`, in this case the accuracy. It is worth setting `batch_size` here as well, since evaluation with the default batch size (32) can be slower."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2500/2500 [==============================] - 5s     \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.33948203306198121, 0.85440000076293943]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(X_test, y_test, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([27916, 27916, 27916, 27916, 27916, 27916, 27916, 27916, 27916,\n",
       "       27916, 27916, 27916, 27916, 27916, 27916,     7,  3848,     7,\n",
       "           0,  1881,   726,  3247,   289,  1246,  1210,     1,  2964,\n",
       "        1030,    37,    44,    25,  1283, 22951,  7472,    16,     3,\n",
       "        3590,    11,  1320,  3220,     1,    14,  1036,   717,  1709,\n",
       "        2849,   740,    46,     0,   230,  1103,    36,   105,  3531,\n",
       "           7,     0,  4668,     1,   176,    22,     0,  1842,   129,\n",
       "           0,   105,     4,    94,   399,  8951,  5448,     7,    34,\n",
       "         761,     5,   162,     0,  1773,     2,   250,   321,  3894,\n",
       "           5,     0,  3543,    19,    12,    14,    13,  3304,    46,\n",
       "          33,   100,   595,     4,    82,    15,    47,   272,   172,\n",
       "       13418,  7067,   117,    22,  1379,     5,   175,     0,   390,\n",
       "         514,    16,   250,    11,   918,    39,   496,    36, 14725,\n",
       "           1,   252,    75,     5,   156,  5102,   115,     0,    64,\n",
       "         299,   267,     9,    69,  1236, 19937,     5,    11,     1,\n",
       "         726,  3247,     2,    25,   318,  2165, 12988, 17634,    69,\n",
       "          27, 13040,    84,    14,     3,  2964,  1030,    28,    98,\n",
       "          27,  4903,    48,    18,    15,   144,    11,    24,  5774,\n",
       "           8,   146,   159,    19,    12,    18,    16,  3220,    46,\n",
       "          48,     2,     0,  1612,   109,   803,     3,   981,     0,\n",
       "         156,  5102, 26764,   501,     3, 21642,     4,     0,  9579,\n",
       "        5444,     1,     0,   357,   230,   368,   128,     3,  1673,\n",
       "          36,  2333,     2,    28,    68,  5444,  5693,     1,     3,\n",
       "         156,  5102,  2076,   230,   251,    33,   709,  5527,  2954,\n",
       "           5,     3,   219,   735,     4, 22295,  7777,     2,   501,\n",
       "       21605,     3,  3238,     4,     0,   357,   230,    37,     6,\n",
       "        1815,     5,   120,  3220, 16539,     1,  5527,   183,  5181,\n",
       "         827,     7,     0,    86,    20,   115,    28,  1985,  1459,\n",
       "          19,    12,  1236,    56,   500,     7,    10,    20,     6,\n",
       "        1491,  9453,     1,    10,   556,   203,   155,  9453,    13,\n",
       "       22669,     2,   185,    11,    59,    28,    91,     3, 13204,\n",
       "           4,   385,    60,    27,  1247,   458,    37,    24,    63,\n",
       "       27915,     1,    28,   289,     0,   167,   238,     4,   211,\n",
       "           7,     0, 10587,  7228,    20,     0, 11268,     1,  1021,\n",
       "           9,    69,   238,     4,  3391,    16,  9453,     7,    11,\n",
       "         228,  1308,   134,   115,    28,     2,  5527,    24,   109,\n",
       "        4379,   186,    33,  3247,  1709,     2,   359,     0,   156,\n",
       "        5102,   553,     1,   309,     0,  3399,  5740,  1491,   260,\n",
       "           5,   386,    58,    16,  5527,    13,   139,   321,    84,\n",
       "         143,   238,     4,     7, 27915,  2735,   149,     7,     0,\n",
       "       27915,  2574,    19,    12,  7067,   803,     0,   216,   166,\n",
       "          32,     3,    50,  5201,   218,    53,   111,   186,    59,\n",
       "           7,    10,    20,     1,    45,    48,    13,   100,   324,\n",
       "           8,    60,    27,    34,  6747,  1476,    20,     1,  3848,\n",
       "           7,     0,  1881,  1165,    34,   810,    15,   114,   196,\n",
       "         869,    11,   331,    19,    12,  7067,    78,    91,    50,\n",
       "         351,     4,     0,   156,  5102, 10088,     2,     0,   787,\n",
       "        6641,     1,    47,     4,     0,   167,  2512,     4,   644,\n",
       "          24,   294,   338,     7,    21,     0, 10088,     1,     7,\n",
       "         187,  3848,     7,     0,  1881,     6,    43,    82,    22,\n",
       "       24748,    52,    35,    63,   140,     7,    65,   199,   114,\n",
       "         587,     1,    53,   716,   171,     5,    21,     0, 10088,\n",
       "          19,    12,  3848,     7,     0,  1881,   121,   307,  2584,\n",
       "          37,    13,  3375,    16,     8,     1,   149,   135,   412,\n",
       "       13418,  7067,    76,    47,   533,   156,  5102,  5048,    82,\n",
       "           7,     0,   390,   542,     1])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test[0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.95367777]], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(X_test[0][None,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2500, 500)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 500)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test[0][None,:].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.88365299]], dtype=float32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentence = \"donald trump . \".lower()\n",
    "sentence_num = [word2num[w] if w in word2num else word2num['<Other>'] for w in sentence.split()]\n",
    "sentence_num = [word2num['<PAD>']]*(500-len(sentence_num)) + sentence_num\n",
    "sentence_num = np.array(sentence_num)\n",
    "model.predict(sentence_num[None,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2num['washington']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.01486305,  0.00178486,  0.02659723,  0.03631897, -0.05354311,\n",
       "        0.04314227, -0.03424839, -0.03875845, -0.02021969, -0.01539999,\n",
       "        0.03420364, -0.00446477, -0.0043705 , -0.05768508, -0.02457126,\n",
       "        0.05868063, -0.04874371, -0.05814709, -0.02626852, -0.01279875,\n",
       "        0.0231608 , -0.03791287, -0.01536598, -0.03500497,  0.01831801,\n",
       "       -0.01325543,  0.00514135, -0.03801096, -0.02921454,  0.00427667,\n",
       "       -0.02751204,  0.03453749], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_weights()[0][word2num['trump']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.0087095 , -0.04241782,  0.01230696, -0.03714038, -0.06135994,\n",
       "        0.00760195,  0.02777364,  0.07299132, -0.06609873, -0.02097359,\n",
       "        0.01137068,  0.01731196,  0.03382057, -0.0036195 , -0.0306195 ,\n",
       "        0.05178453, -0.01523854,  0.00716526, -0.0335728 , -0.07671038,\n",
       "        0.02215168,  0.00062636, -0.00311515, -0.07438745, -0.00466766,\n",
       "       -0.00450422, -0.03426327,  0.06381596, -0.01512961, -0.02493496,\n",
       "        0.07057063,  0.01613435], dtype=float32)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_weights()[0][word2num['hitler']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.02137261, -0.01078706,  0.05461324,  0.0603575 , -0.01319825,\n",
       "        0.02483638, -0.03697665, -0.00418437,  0.00427216,  0.03785726,\n",
       "       -0.01167645, -0.00794956,  0.01235597,  0.02411322,  0.01329259,\n",
       "       -0.031885  , -0.04764228,  0.03314191,  0.04316465,  0.01771139,\n",
       "       -0.03431438, -0.0459654 ,  0.00933751,  0.04778398,  0.0531096 ,\n",
       "       -0.06432696, -0.00759748, -0.06147803,  0.02870869,  0.03668656,\n",
       "        0.01724944, -0.05900047], dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_weights()[0][word2num['<PAD>']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(27917, 32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_weights()[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
